﻿using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Configuration;
using PmSoft.Core;
using PmSoft.Web.Abstractions.ErrorCode;

namespace PmSoft.Web.Abstractions.Middlewares;

// 中间件类：实现请求频率限制
public class RateLimitMiddleware : IMiddleware
{
	private readonly IMemoryCache _cache;
	private readonly int _maxRequestsPerMinute; // 从配置中读取的最大请求次数
												// 滑动窗口时间（1分钟）
	private static readonly TimeSpan WindowDuration = TimeSpan.FromMinutes(1);

	public RateLimitMiddleware(IMemoryCache cache, IConfiguration configuration)
	{
		_cache = cache;
		// 从appsettings.json读取MaxRequestsPerMinute，默认为100
		_maxRequestsPerMinute = configuration.GetValue("RateLimit:MaxRequestsPerMinute", 100);
	}

	public async Task InvokeAsync(HttpContext context, RequestDelegate next)
	{
		// 获取客户端真实IP（支持反向代理）
		string clientIp = GetClientIpAddress(context);
		string cacheKey = $"RateLimit_{clientIp}";

		// 从缓存中获取或初始化访问记录
		if (!_cache.TryGetValue(cacheKey, out RateLimitInfo? rateInfo))
		{
			rateInfo = new RateLimitInfo
			{
				RequestCount = 0,
				WindowStart = DateTime.UtcNow
			};
		}

		// 检查滑动窗口是否需要重置
		if (DateTime.UtcNow - rateInfo.WindowStart > WindowDuration)
		{
			rateInfo.RequestCount = 0;
			rateInfo.WindowStart = DateTime.UtcNow;
		}

		// 检查请求次数是否超过限制
		if (rateInfo.RequestCount >= _maxRequestsPerMinute)
		{
			// 返回429 Too Many Requests状态码
			context.Response.StatusCode = ErrorCodeManager.GetError("REQRATELIMIT").HttpStatus;
			context.Response.ContentType = "application/json";
			await context.Response.WriteAsync(Json.Stringify(ApiResult.Error("REQRATELIMIT", "请求频率超过限制，请稍后再试")));
			return;
		}

		// 增加请求计数并更新缓存
		rateInfo.RequestCount++;
		_cache.Set(cacheKey, rateInfo, WindowDuration);

		// 继续处理请求
		await next(context);
	}

	// 获取客户端真实IP的辅助方法
	private static string GetClientIpAddress(HttpContext context)
	{
		// 优先从X-Forwarded-For头获取IP（适用于反向代理场景）
		string? forwardedFor = context.Request.Headers["X-Forwarded-For"].FirstOrDefault();
		if (!string.IsNullOrEmpty(forwardedFor))
		{
			// X-Forwarded-For可能包含多个IP（代理链），取第一个为客户端真实IP
			return forwardedFor.Split(',').First().Trim();
		}

		// 如果没有X-Forwarded-For头，则回退到RemoteIpAddress
		return context.Connection.RemoteIpAddress?.ToString() ?? "unknown";
	}
}

// 访问记录的实体类
public class RateLimitInfo
{
	public int RequestCount { get; set; } // 请求计数
	public DateTime WindowStart { get; set; } // 窗口开始时间
}

// 扩展方法，便于在Startup或Program中使用
public static class RateLimitMiddlewareExtensions
{
	public static IApplicationBuilder UseRateLimit(this IApplicationBuilder builder)
	{
		return builder.UseMiddleware<RateLimitMiddleware>();
	}
}